# Grid-search CNN hyperparameters for classifying genes as expressed vs.
# unexpressed from their promoter + terminator sequences.
import pandas as pd
import Bio.SeqIO
import numpy as np
import itertools
from keras.layers import Dense,Activation,Flatten,Dropout,Conv2D, MaxPooling2D
from keras.models import Sequential
from keras.utils import np_utils
from keras.callbacks import EarlyStopping
from keras import backend as K
from sklearn.metrics import confusion_matrix

# Input files: promoter and terminator FASTA sequences plus a per-gene
# expression table (tab-separated, indexed by 'Geneid', with a 'category'
# column containing 'expressed' / 'unexpressed').
SEQ_FILE_PRO='promoters.fa'
SEQ_FILE_TER='terminators.fa'
DATA_FILE='un_vs_expr.csv'

data=pd.read_csv(DATA_FILE,sep='\t')
data=data.set_index('Geneid')

# empty lists 
geneIDs=[]
categories=[]

seqs=[]
seqs_pro=[]
seqs_ter=[]

# read sequences 
# NOTE(review): zip() pairs the two FASTA files record-by-record and only
# x.id (the promoter record) is checked against the table -- this assumes
# both files list the same genes in the same order; confirm upstream file
# generation guarantees that.
for x,y in zip(Bio.SeqIO.parse(SEQ_FILE_PRO,'fasta'), Bio.SeqIO.parse(SEQ_FILE_TER,'fasta') ):
     geneID=x.id
     seq_pro=str(x.seq)
     seq_ter=str(y.seq)
     # concatenated promoter+terminator string used as the CNN input
     seq=seq_pro+seq_ter
     # keep only genes that also appear in the expression table
     if geneID in data.index:
         seqs.append(seq)
         seqs_pro.append(seq_pro)
         seqs_ter.append(seq_ter)
         geneIDs.append(geneID)
         category=data['category'][geneID]
         # binary label: 1 = expressed, 0 = anything else (unexpressed)
         if category=='expressed':
              categories.append(1)
         else:
              categories.append(0)

# one-hot encoding
# Maps each nucleotide to a 4-element indicator vector (row order A,C,G,T).
# NOTE(review): the name `dict` shadows the builtin dict for the rest of
# the module -- consider renaming in a follow-up change.
dict={'A':[1,0,0,0],'C':[0,1,0,0],'G':[0,0,1,0],'T':[0,0,0,1]}
# one-hot encoding for methylation sites
# dict={'X':[1,0,0,0],'Y':[0,1,0,0],'Z':[0,0,1,0],'C':[0,0,0,1],'N':[0,0,0,0],'A':[0,0,0,0],'T':[0,0,0,0],'G':[0,0,0,0]}

def one_hot_encoding(seq, mapping=None):
    """Return a (4, len(seq)) one-hot matrix for a nucleotide string.

    seq -- string over the mapping's alphabet (default: A/C/G/T).
    mapping -- optional dict mapping each character to a length-4 vector;
        defaults to the standard A,C,G,T indicator table. Pass an
        alternative table (e.g. the methylation-site encoding) to change
        the alphabet without mutating module state.

    Raises KeyError if seq contains a character absent from the mapping.

    Fix vs. original: no longer reads the module-level global named `dict`
    (which shadowed the builtin); the table is now a self-contained,
    overridable default.
    """
    if mapping is None:
        mapping = {'A': [1, 0, 0, 0], 'C': [0, 1, 0, 0],
                   'G': [0, 0, 1, 0], 'T': [0, 0, 0, 1]}
    one_hot_encoded = np.zeros(shape=(4, len(seq)))
    for i, nt in enumerate(seq):
        one_hot_encoded[:, i] = mapping[nt]
    return one_hot_encoded

# Encode every kept sequence as a (4, length) matrix and append a trailing
# channel axis -> array of shape (n_genes, 4, length, 1) for Conv2D.
one_hot_seqs=np.expand_dims(np.array([ one_hot_encoding(seq) for seq in seqs],dtype=np.float64),3)

# conceal start and stop codons from the CNN
# NOTE(review): positions 1000:1003 and 1997:2000 assume each promoter and
# terminator sequence is exactly 1000 nt (start codon right after the
# promoter half, stop codon at the end of the 2000-nt concatenation) --
# confirm against how the FASTA files were extracted.
one_hot_seqs[:,:,1000:1003,:]=0
one_hot_seqs[:,:,1997:2000,:]=0

geneIDs=np.array(geneIDs)
# labels -> one-hot (n_genes, 2) matrix for the 2-way softmax output
categories=np_utils.to_categorical(np.array(categories),2)

# hyperparameters to test
hyperparams = {
    'conv1_filters': [64, 128],
    'conv2_filters': [64, 128],
    'conv3_filters': [64, 128],
    'conv_width':    [4, 8],
    'pool_width':    [4, 8],
    'pool_stride':   [4, 8],
    'dropout':       [0.1, 0.25],
    'dense1_units':  [64, 128],
    'dense2_units':  [64, 128],
    'conv_layers':   [1, 2, 3],
    'dense_layers':  [2, 3],
}

# Fixed positional order of the values inside each combination; build()
# unpacks them in exactly this order.
_HP_ORDER = ['conv1_filters', 'conv2_filters', 'conv3_filters', 'conv_width',
             'pool_width', 'pool_stride', 'dropout', 'dense1_units',
             'dense2_units', 'conv_layers', 'dense_layers']

# Cartesian product of all hyperparameter values. itertools.product varies
# the last key fastest, matching the original 11-level nested comprehension.
hps = [list(combo)
       for combo in itertools.product(*(hyperparams[key] for key in _HP_ORDER))]

# When some layers do not exist, their parameters should not exist too:
# mask unused values to None so combinations differing only in unused
# values collapse to one entry.
for hp in hps:
    conv_layers = hp[9]
    dense_layers = hp[10]
    if conv_layers == 1:      # only conv block 1 is built
        hp[1] = None          # conv2_filters unused
        hp[2] = None          # conv3_filters unused
    elif conv_layers == 2:    # conv blocks 1-2 are built
        hp[2] = None          # conv3_filters unused
    if dense_layers == 2:     # dense head is dense1 + output only
        hp[8] = None          # dense2_units unused

# Drop duplicate combinations. The original code called `hps.sort` without
# parentheses (a no-op) and then used itertools.groupby, which removes only
# *adjacent* duplicates; sorting would also fail under Python 3 because
# None and int do not compare. An order-preserving set-based dedupe fixes
# both problems.
_seen = set()
_unique = []
for hp in hps:
    key = tuple(hp)
    if key not in _seen:
        _seen.add(key)
        _unique.append(hp)
hps = _unique

# Build neural network
def build(DNA_length, hp):
    """Assemble an (uncompiled) Sequential CNN for one hyperparameter set.

    DNA_length -- width of the one-hot input (number of nucleotides).
    hp -- list of 11 values in the fixed order produced by the grid above:
          conv1/2/3 filter counts, conv kernel width, pool width, pool
          stride, dropout rate, dense1/dense2 unit counts, number of conv
          blocks (1-3) and dense layers (2-3). Entries for unused layers
          are None and are never read.

    Returns a keras model ending in a 2-way softmax.
    """
    (conv1_filters, conv2_filters, conv3_filters, conv_width, pool_width,
     pool_stride, dropout, dense1_units, dense2_units, conv_layers,
     dense_layers) = hp

    model = Sequential()

    # First conv block: the initial convolution spans all 4 one-hot rows
    # ('valid' padding collapses that axis); later layers work at height 1.
    model.add(Conv2D(conv1_filters, kernel_size=(4, conv_width),
                     padding='valid', input_shape=[4, DNA_length, 1]))
    model.add(Activation('relu'))
    model.add(Conv2D(conv1_filters, kernel_size=(1, conv_width), padding='same'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(1, pool_width),
                           strides=(1, pool_stride), padding='same'))
    model.add(Dropout(dropout))

    # Optional second and third conv blocks: each adds two conv+relu pairs,
    # one pooling step and one dropout, mirroring the first block.
    for depth, n_filters in ((2, conv2_filters), (3, conv3_filters)):
        if conv_layers >= depth:
            for _ in range(2):
                model.add(Conv2D(n_filters, kernel_size=(1, conv_width), padding='same'))
                model.add(Activation('relu'))
            model.add(MaxPooling2D(pool_size=(1, pool_width),
                                   strides=(1, pool_stride), padding='same'))
            model.add(Dropout(dropout))

    # Dense head.
    model.add(Flatten())
    model.add(Dense(dense1_units))
    model.add(Activation('relu'))
    model.add(Dropout(dropout))

    if dense_layers >= 3:
        model.add(Dense(dense2_units))
        model.add(Activation('relu'))

    model.add(Dense(2))
    model.add(Activation('softmax'))
    return model

# train models for different combinations of parameters
def compare_models(TRAINING_ONE_HOT_SEQS,TESTING_ONE_HOT_SEQS,training_indices,testing_indices,split,test_subsample,shuffle,predictor):
      """Train one model per hyperparameter combination in `hps` and append
      each combination's test accuracy to the 'compare_models' file.

      TRAINING_ONE_HOT_SEQS / TESTING_ONE_HOT_SEQS -- full encoded sequence
          arrays of shape (n, 4, length, 1); subset here by the index
          arrays (the caller passes the same array for both).
      training_indices / testing_indices -- integer row positions into the
          module-level `geneIDs` / `categories` arrays.
      split, test_subsample, shuffle, predictor -- run metadata written
          verbatim into each output row.

      Relies on module-level globals: hps, categories, geneIDs, build().
      """
      #splitting into training set and test set
      TRAINING_ONE_HOT_SEQS=TRAINING_ONE_HOT_SEQS[training_indices]
      TESTING_ONE_HOT_SEQS=TESTING_ONE_HOT_SEQS[testing_indices]
      TRAINING_CATEGORIES=categories[training_indices]
      TESTING_CATEGORIES=categories[testing_indices]
      TESTING_GENES=geneIDs[testing_indices]

      # sequence length = axis 2 of the (n, 4, length, 1) array
      LENGTH=TRAINING_ONE_HOT_SEQS.shape[2]
      # NOTE(review): monitors *training* loss, not val_loss -- confirm
      # that is intentional for early stopping.
      callbacks = [EarlyStopping(monitor='loss', patience=3, verbose=0)]

      for hp in hps:
         model=build(LENGTH,hp)
         model.compile(loss='binary_crossentropy',optimizer='Adam',metrics=['accuracy'])
         model.fit(x=TRAINING_ONE_HOT_SEQS,y=TRAINING_CATEGORIES,batch_size=256,epochs=40,validation_data=(TESTING_ONE_HOT_SEQS,TESTING_CATEGORIES),callbacks=callbacks)

         prediction=model.predict(TESTING_ONE_HOT_SEQS)
         predicted_categories= np.argmax(prediction,axis=1)
         real_categories=np.argmax(TESTING_CATEGORIES,axis=1)
         # NOTE(review): the 4-way .ravel() unpack assumes both classes
         # occur among real and predicted labels; otherwise the matrix is
         # smaller than 2x2 and this raises.
         tn,fp,fn,tp=confusion_matrix(real_categories,predicted_categories).ravel()
         test_total=tp+tn+fp+fn
         # *1.0 forces float division (Python 2 style)
         accuracy=(tp+tn)/(test_total*1.0)

         # append one tab-separated row: the 11 hyperparameters, then the
         # run metadata and the resulting accuracy
         with open('compare_models', "a") as result_file:
            result_file.write("\t".join([str(x) for x in hp]))
            result_file.write("\t")
            result_file.write("\t".join([split,test_subsample,shuffle,predictor,str(accuracy)]))
            result_file.write("\n")
         # free the graph between runs so memory does not accumulate
         del model
         K.clear_session()

# Write the header line to file: the 11 hyperparameter columns followed by
# the run-metadata and accuracy columns, all tab-separated.
header_fields = ['conv1_filters', 'conv2_filters', 'conv3_filters',
                 'conv_width', 'pool_width', 'pool_stride', 'dropout',
                 'dense1_units', 'dense2_units', 'conv_layers',
                 'dense_layers',
                 'split', 'test_subsample', 'shuffle', 'predictor',
                 'accuracy']
with open('compare_models', "a") as result_file:
    result_file.write("\t".join(header_fields) + "\n")

# Start training models
# The last 10 columns of the expression table hold cross-validation split
# assignments; each cell names the subsample a gene belongs to.
splits=data.columns.values[-10:]
subsamples=['subsample4','subsample5']

for split in splits:
   for test_subsample in subsamples:
       # genes assigned to the held-out subsample, restricted to genes we
       # actually have sequences for
       testing_genes = data[data[split]==test_subsample].index
       testing_genes = np.array(list(set(testing_genes).intersection(set(geneIDs))))
       # In testing set, downsample expressed genes to make them balanced with unexpressed genes
       # NOTE(review): np.random.choice(..., replace=False) raises if there
       # are fewer expressed than unexpressed genes -- assumed not to
       # happen with this dataset.
       unexpressed_testing_genes=[gene for gene in testing_genes if data['category'][gene]=='unexpressed'  ]
       expressed_testing_genes=[gene for gene in testing_genes if data['category'][gene]=='expressed']
       expressed_testing_genes=np.random.choice(expressed_testing_genes,len(unexpressed_testing_genes),replace=False)
       testing_genes=np.concatenate((expressed_testing_genes,unexpressed_testing_genes),axis=0)

       training_genes= data[data[split]!=test_subsample].index
       training_genes= np.array(list(set(training_genes).intersection(set(geneIDs))))
       # In training set, downsample expressed genes to make them balanced with unexpressed genes
       unexpressed_training_genes=[gene for gene in training_genes if data['category'][gene]=='unexpressed'  ]
       expressed_training_genes=[gene for gene in training_genes if data['category'][gene]=='expressed']
       expressed_training_genes=np.random.choice(expressed_training_genes,len(unexpressed_training_genes),replace=False)
       training_genes=np.concatenate((expressed_training_genes,unexpressed_training_genes),axis=0)

       # map gene IDs back to row positions in one_hot_seqs / categories
       training_indices=(np.array([np.where(geneIDs==element)[0][0] for element in training_genes])).astype('int64')
       testing_indices=(np.array([np.where(geneIDs==element)[0][0] for element in testing_genes])).astype('int64')
       np.random.shuffle(training_indices)
       # 'None' (shuffle) and 'Pro_and_Ter' (predictor) are metadata labels
       # written verbatim to the results file
       compare_models(one_hot_seqs,one_hot_seqs,training_indices,testing_indices,split,test_subsample,'None','Pro_and_Ter')
